{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import subprocess\n",
    "from scipy.misc import imread\n",
    "%matplotlib inline\n",
    "\n",
    "from train import build_forward\n",
    "from utils import googlenet_load, train_utils\n",
    "from utils.annolist import AnnotationLib as al\n",
    "from utils.stitch_wrapper import stitch_rects\n",
    "from utils.train_utils import add_rectangles\n",
    "from utils.rect import Rect\n",
    "from utils.stitch_wrapper import stitch_rects\n",
    "from evaluate import add_rectangles\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Which hyperparameter file and which checkpoint iteration to evaluate.\n",
    "hypes_file = './hypes/overfeat_rezoom.json'\n",
    "iteration = 150000\n",
    "\n",
    "# Ground-truth annotation file, and the file this run's predictions are written to.\n",
    "# The prediction file name embeds the iteration and the hypes file's base name.\n",
    "true_idl = './data/brainwash/brainwash_val.idl'\n",
    "hypes_name = os.path.basename(hypes_file).replace('.json', '')\n",
    "pred_idl = './output/%d_val_%s.idl' % (iteration, hypes_name)\n",
    "\n",
    "# Load the hyperparameter dict (H) and parse the ground-truth annotations.\n",
    "with open(hypes_file, 'r') as hypes_fp:\n",
    "    H = json.load(hypes_fp)\n",
    "true_annos = al.parse(true_idl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Evaluation settings: cap on the number of validation images, the confidence\n",
    "# cutoff passed to add_rectangles, and which of the first images get plotted.\n",
    "max_images = 500\n",
    "min_conf = 0.3\n",
    "plot_every, plot_until = 10, 200\n",
    "\n",
    "# Build the inference graph for a single image (expand_dims adds the batch axis).\n",
    "tf.reset_default_graph()\n",
    "googlenet = googlenet_load.init(H)\n",
    "x_in = tf.placeholder(tf.float32, name='x_in', shape=[H['image_height'], H['image_width'], 3])\n",
    "if H['use_rezoom']:\n",
    "    pred_boxes, pred_logits, pred_confidences, pred_confs_deltas, pred_boxes_deltas = build_forward(H, tf.expand_dims(x_in, 0), googlenet, 'test', reuse=None)\n",
    "    grid_area = H['grid_height'] * H['grid_width']\n",
    "    # Replace the confidences with a 2-way softmax over the rezoom logits,\n",
    "    # reshaped to [grid_area, rnn_len, 2].\n",
    "    pred_confidences = tf.reshape(tf.nn.softmax(tf.reshape(pred_confs_deltas, [grid_area * H['rnn_len'], 2])), [grid_area, H['rnn_len'], 2])\n",
    "    if H['reregress']:\n",
    "        pred_boxes = pred_boxes + pred_boxes_deltas\n",
    "else:\n",
    "    pred_boxes, pred_logits, pred_confidences = build_forward(H, tf.expand_dims(x_in, 0), googlenet, 'test', reuse=None)\n",
    "\n",
    "saver = tf.train.Saver()\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.initialize_all_variables())\n",
    "    saver.restore(sess, './data/overfeat_rezoom/save.ckpt-%d' % iteration)\n",
    "\n",
    "    annolist = al.AnnoList()\n",
    "    num_images = 0\n",
    "    t = time.time()\n",
    "    # Iterate rather than index range(500): a validation set with fewer than\n",
    "    # max_images entries then ends the loop instead of raising IndexError.\n",
    "    for i, true_anno in enumerate(true_annos):\n",
    "        if i >= max_images:\n",
    "            break\n",
    "        img = imread('./data/brainwash/%s' % true_anno.imageName)\n",
    "        feed = {x_in: img}\n",
    "        (np_pred_boxes, np_pred_confidences) = sess.run([pred_boxes, pred_confidences], feed_dict=feed)\n",
    "\n",
    "        pred_anno = al.Annotation()\n",
    "        pred_anno.imageName = true_anno.imageName\n",
    "        # NOTE(review): add_rectangles is imported twice in the import cell; the\n",
    "        # later `from evaluate import add_rectangles` is the binding used here.\n",
    "        new_img, rects = add_rectangles(H, [img], np_pred_confidences, np_pred_boxes,\n",
    "                                        use_stitching=True, rnn_len=H['rnn_len'], min_conf=min_conf)\n",
    "        pred_anno.rects = rects\n",
    "        annolist.append(pred_anno)\n",
    "        num_images += 1\n",
    "\n",
    "        # Show every plot_every-th result among the first plot_until images.\n",
    "        if i % plot_every == 0 and i < plot_until:\n",
    "            plt.figure(figsize=(12, 12))\n",
    "            plt.imshow(new_img)\n",
    "        if i % 100 == 0:\n",
    "            print(i)\n",
    "\n",
    "    # Throughput over the images actually processed (plotting time is included).\n",
    "    # Skipped when nothing was processed to avoid dividing by zero.\n",
    "    if num_images:\n",
    "        avg_time = (time.time() - t) / num_images\n",
    "        print('%f images/sec' % (1. / avg_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "\n",
    "# Make sure the directory for the prediction file exists before saving into it.\n",
    "pred_dir = os.path.dirname(pred_idl)\n",
    "if pred_dir and not os.path.isdir(pred_dir):\n",
    "    os.makedirs(pred_dir)\n",
    "annolist.save(pred_idl)\n",
    "\n",
    "# Run the RPC script on ground truth vs. predictions. Arguments are passed as a\n",
    "# list (no shell) so paths with spaces or shell metacharacters are handled safely;\n",
    "# universal_newlines=True returns text on both Python 2 and Python 3.\n",
    "iou_threshold = 0.5\n",
    "rpc_args = ['./utils/annolist/doRPC.py', '--minOverlap', '%f' % iou_threshold, true_idl, pred_idl]\n",
    "rpc_cmd = ' '.join(rpc_args)\n",
    "print('$ %s' % rpc_cmd)\n",
    "rpc_output = subprocess.check_output(rpc_args, universal_newlines=True)\n",
    "print(rpc_output)\n",
    "\n",
    "# The last non-blank line of the script output is used as the results text file path.\n",
    "non_blank_lines = [line for line in rpc_output.splitlines() if line.strip()]\n",
    "if not non_blank_lines:\n",
    "    raise ValueError('doRPC.py produced no output; cannot locate the results file')\n",
    "txt_file = non_blank_lines[-1]\n",
    "\n",
    "# Plot the results file to a PNG and display it as the cell output.\n",
    "output_png = 'output/results.png'\n",
    "plot_args = ['./utils/annolist/plotSimple.py', txt_file, '--output', output_png]\n",
    "plot_cmd = ' '.join(plot_args)\n",
    "print('$ %s' % plot_cmd)\n",
    "plot_output = subprocess.check_output(plot_args, universal_newlines=True)\n",
    "Image(filename=output_png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
